---
title: fastai U-Net
keywords: fastai
sidebar: home_sidebar
---
{% raw %}
{% endraw %} {% raw %}
%reload_ext autoreload
%autoreload 2
%matplotlib inline
{% endraw %} {% raw %}
{% endraw %} {% raw %}
import sys
sys.path.append('..')
from superres.datasets import *
from superres.databunch import *
{% endraw %} {% raw %}
seed = 8610
random.seed(seed)
np.random.seed(seed)
{% endraw %}

## DataBunch

{% raw %}
train_hr = div2k_train_hr_crop_256
{% endraw %} {% raw %}
in_size = 256
out_size = 256
scale = 4
bs = 10
{% endraw %} {% raw %}
data = create_sr_databunch(train_hr, in_size=in_size, out_size=out_size, scale=scale, bs=bs, seed=seed)
print(data)
data.show_batch()
ImageDataBunch;

Train: LabelList (25245 items)
x: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Valid: LabelList (6311 items)
x: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Test: None
{% endraw %}

## Training

{% raw %}
model = models.resnet34
loss_func = MSELossFlat()
metrics = [m_psnr, m_ssim]
wd = 1e-3
y_range = (-3.,3.)
model_name = 'sr_unet'
{% endraw %} {% raw %}
learn = unet_learner(data, model, wd=wd, metrics=metrics, y_range=y_range, loss_func=loss_func,
                     blur=True, norm_type=NormType.Weight, self_attention=True)
learn.path = Path('.')
{% endraw %} {% raw %}
lr_find(learn)
learn.recorder.plot(suggestion=True)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Min numerical gradient: 1.91E-04
Min loss divided by 10: 1.10E-01
{% endraw %} {% raw %}
lr = 1e-3
lrs = slice(lr)
epoch = 3
pct_start = 0.3
wd = 1e-3
save_fname = model_name
{% endraw %} {% raw %}
callbacks = [ShowGraph(learn), SaveModelCallback(learn, name=save_fname)]
{% endraw %} {% raw %}
learn.fit_one_cycle(epoch, lrs, pct_start=pct_start, wd=wd, callbacks=callbacks)
epoch train_loss valid_loss m_psnr m_ssim time
0 0.090948 0.057407 32.308601 0.425962 09:00
1 0.079309 0.054548 32.957603 0.435841 08:58
2 0.073125 0.053997 33.810516 0.441839 08:59
Better model found at epoch 0 with valid_loss value: 0.05740687996149063.
Better model found at epoch 1 with valid_loss value: 0.054548367857933044.
Better model found at epoch 2 with valid_loss value: 0.05399680882692337.
{% endraw %} {% raw %}
learn.show_results()
{% endraw %}

## Test

{% raw %}
test_hr = set14_hr
{% endraw %} {% raw %}
il_test_x = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=in_size, scale=4, sizeup=True))
il_test_y = ImageImageList.from_folder(test_hr, after_open=partial(after_open_image, size=out_size))
{% endraw %} {% raw %}
_ = learn.load(save_fname)
{% endraw %} {% raw %}
sr_test(learn, il_test_x, il_test_y, model_name)
bicubic: PSNR:24.11,SSIM:0.7822
sr_unet:	 PSNR:24.98,SSIM:0.8146
{% endraw %}

## Report

{% raw %}
model
<function torchvision.models.resnet.resnet34(pretrained=False, progress=True, **kwargs)>
{% endraw %} {% raw %}
learn.summary()
DynamicUnet
======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Conv2d               [64, 128, 128]       9,408      False     
______________________________________________________________________
BatchNorm2d          [64, 128, 128]       128        True      
______________________________________________________________________
ReLU                 [64, 128, 128]       0          False     
______________________________________________________________________
MaxPool2d            [64, 64, 64]         0          False     
______________________________________________________________________
Conv2d               [64, 64, 64]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
ReLU                 [64, 64, 64]         0          False     
______________________________________________________________________
Conv2d               [64, 64, 64]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
ReLU                 [64, 64, 64]         0          False     
______________________________________________________________________
Conv2d               [64, 64, 64]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [64, 64, 64]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
ReLU                 [64, 64, 64]         0          False     
______________________________________________________________________
Conv2d               [64, 64, 64]         36,864     False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [128, 32, 32]        73,728     False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
ReLU                 [128, 32, 32]        0          False     
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
Conv2d               [128, 32, 32]        8,192      False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
ReLU                 [128, 32, 32]        0          False     
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
ReLU                 [128, 32, 32]        0          False     
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
ReLU                 [128, 32, 32]        0          False     
______________________________________________________________________
Conv2d               [128, 32, 32]        147,456    False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        294,912    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
ReLU                 [256, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        32,768     False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
ReLU                 [256, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
ReLU                 [256, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
ReLU                 [256, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
ReLU                 [256, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
ReLU                 [256, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [256, 16, 16]        589,824    False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [512, 8, 8]          1,179,648  False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
ReLU                 [512, 8, 8]          0          False     
______________________________________________________________________
Conv2d               [512, 8, 8]          2,359,296  False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
Conv2d               [512, 8, 8]          131,072    False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
Conv2d               [512, 8, 8]          2,359,296  False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
ReLU                 [512, 8, 8]          0          False     
______________________________________________________________________
Conv2d               [512, 8, 8]          2,359,296  False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
Conv2d               [512, 8, 8]          2,359,296  False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
ReLU                 [512, 8, 8]          0          False     
______________________________________________________________________
Conv2d               [512, 8, 8]          2,359,296  False     
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
BatchNorm2d          [512, 8, 8]          1,024      True      
______________________________________________________________________
ReLU                 [512, 8, 8]          0          False     
______________________________________________________________________
Conv2d               [1024, 8, 8]         4,719,616  True      
______________________________________________________________________
ReLU                 [1024, 8, 8]         0          False     
______________________________________________________________________
Conv2d               [512, 8, 8]          4,719,104  True      
______________________________________________________________________
ReLU                 [512, 8, 8]          0          False     
______________________________________________________________________
Conv2d               [1024, 8, 8]         525,312    True      
______________________________________________________________________
PixelShuffle         [256, 16, 16]        0          False     
______________________________________________________________________
ReplicationPad2d     [256, 17, 17]        0          False     
______________________________________________________________________
AvgPool2d            [256, 16, 16]        0          False     
______________________________________________________________________
ReLU                 [1024, 8, 8]         0          False     
______________________________________________________________________
BatchNorm2d          [256, 16, 16]        512        True      
______________________________________________________________________
Conv2d               [512, 16, 16]        2,359,808  True      
______________________________________________________________________
ReLU                 [512, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [512, 16, 16]        2,359,808  True      
______________________________________________________________________
ReLU                 [512, 16, 16]        0          False     
______________________________________________________________________
ReLU                 [512, 16, 16]        0          False     
______________________________________________________________________
Conv2d               [1024, 16, 16]       525,312    True      
______________________________________________________________________
PixelShuffle         [256, 32, 32]        0          False     
______________________________________________________________________
ReplicationPad2d     [256, 33, 33]        0          False     
______________________________________________________________________
AvgPool2d            [256, 32, 32]        0          False     
______________________________________________________________________
ReLU                 [1024, 16, 16]       0          False     
______________________________________________________________________
BatchNorm2d          [128, 32, 32]        256        True      
______________________________________________________________________
Conv2d               [384, 32, 32]        1,327,488  True      
______________________________________________________________________
ReLU                 [384, 32, 32]        0          False     
______________________________________________________________________
Conv2d               [384, 32, 32]        1,327,488  True      
______________________________________________________________________
ReLU                 [384, 32, 32]        0          False     
______________________________________________________________________
Conv1d               [48, 1024]           18,432     True      
______________________________________________________________________
Conv1d               [48, 1024]           18,432     True      
______________________________________________________________________
Conv1d               [384, 1024]          147,456    True      
______________________________________________________________________
ReLU                 [384, 32, 32]        0          False     
______________________________________________________________________
Conv2d               [768, 32, 32]        295,680    True      
______________________________________________________________________
PixelShuffle         [192, 64, 64]        0          False     
______________________________________________________________________
ReplicationPad2d     [192, 65, 65]        0          False     
______________________________________________________________________
AvgPool2d            [192, 64, 64]        0          False     
______________________________________________________________________
ReLU                 [768, 32, 32]        0          False     
______________________________________________________________________
BatchNorm2d          [64, 64, 64]         128        True      
______________________________________________________________________
Conv2d               [256, 64, 64]        590,080    True      
______________________________________________________________________
ReLU                 [256, 64, 64]        0          False     
______________________________________________________________________
Conv2d               [256, 64, 64]        590,080    True      
______________________________________________________________________
ReLU                 [256, 64, 64]        0          False     
______________________________________________________________________
ReLU                 [256, 64, 64]        0          False     
______________________________________________________________________
Conv2d               [512, 64, 64]        131,584    True      
______________________________________________________________________
PixelShuffle         [128, 128, 128]      0          False     
______________________________________________________________________
ReplicationPad2d     [128, 129, 129]      0          False     
______________________________________________________________________
AvgPool2d            [128, 128, 128]      0          False     
______________________________________________________________________
ReLU                 [512, 64, 64]        0          False     
______________________________________________________________________
BatchNorm2d          [64, 128, 128]       128        True      
______________________________________________________________________
Conv2d               [96, 128, 128]       165,984    True      
______________________________________________________________________
ReLU                 [96, 128, 128]       0          False     
______________________________________________________________________
Conv2d               [96, 128, 128]       83,040     True      
______________________________________________________________________
ReLU                 [96, 128, 128]       0          False     
______________________________________________________________________
ReLU                 [192, 128, 128]      0          False     
______________________________________________________________________
Conv2d               [384, 128, 128]      37,248     True      
______________________________________________________________________
PixelShuffle         [96, 256, 256]       0          False     
______________________________________________________________________
ReLU                 [384, 128, 128]      0          False     
______________________________________________________________________
MergeLayer           [99, 256, 256]       0          False     
______________________________________________________________________
Conv2d               [99, 256, 256]       88,308     True      
______________________________________________________________________
ReLU                 [99, 256, 256]       0          False     
______________________________________________________________________
Conv2d               [99, 256, 256]       88,308     True      
______________________________________________________________________
ReLU                 [99, 256, 256]       0          False     
______________________________________________________________________
MergeLayer           [99, 256, 256]       0          False     
______________________________________________________________________
Conv2d               [3, 256, 256]        300        True      
______________________________________________________________________
SigmoidRange         [3, 256, 256]        0          False     
______________________________________________________________________

Total params: 41,405,588
Total trainable params: 20,137,940
Total non-trainable params: 21,267,648
Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)
Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ 
Loss function : FlattenedLoss
======================================================================
Callbacks functions applied 
{% endraw %}